'''
Proxy API for handling student questions and streaming responses.
This API sits in front of both the model server and the SQL server, and relays streaming responses to clients.
The proxy controls the order of the upstream requests and formats the final responses.
It handles several kinds of requests, both standard and streaming.
Because it is asynchronous, a handler can wait for an upstream response without blocking the server,
so long-running requests do not stall other clients.

To run for debugging:
CUDA_VISIBLE_DEVICES=1 python /scratch/dhoward/Chatbot/api.py

For server running:
sudo supervisorctl start / stop / status myadvisor-chatbot-api

API Local Port Keys
8443: Main API Server
8001: Model Server
8002: SQL Server
8004: FAISS Server

Note: The system is built with Python 3.10.12 in a conda environment.
Python 3.10.12 is installed inside that conda environment and is required for compatibility with most dependencies.
'''

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import httpx
import json
import requests
from typing import List, Optional
from datetime import datetime

# The FastAPI application served by this proxy.
app = FastAPI(
    version="1.0.2",
    title="Student Question API",
    description="API for answering student questions using a hosted LLM",
)

# Allow cross-origin browser access.
# TODO: narrow allow_origins to the myadvisor domain in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Base URLs of the internal services this proxy forwards requests to
# (see the port table in the module docstring).
SERVICES = dict(
    chatbot='http://137.158.59.19:8001',  # model server
    sql='http://137.158.59.19:8002',      # SQL server
    faiss='http://137.158.59.19:8004',    # FAISS server
)

# Base URL of the MyAdvisor backend. Endpoint paths are appended directly
# (e.g. f"{API_URL}advisor_questions_fetch/"), so the trailing slash is required.
API_URL = "https://myadvisor.cs.uct.ac.za/backend/"

# Pydantic models for data passing and validation
class ContextItem(BaseModel):
    """One earlier question/answer exchange, sent by the client as context."""
    question: str
    answer: str

class InferenceRequest(BaseModel):
    """Request body for /ask/stream-json: a student question plus optional prior exchanges."""
    question: str
    # Earlier question/answer pairs; may be omitted or sent as null.
    context: list[ContextItem] | None = []

class RoadmapEvaluationResponse(BaseModel):
    """Response of /roadmap-evaluation."""
    # The generated evaluation text, or a short error description when success is False.
    evaluation: str
    success: bool
    
class RoadmapJsonRequest(BaseModel):
    """Wrapper for a raw roadmap.json payload."""
    # NOTE(review): not referenced by any endpoint visible in this file
    # (/roadmap-evaluation takes a plain dict); confirm whether it is still needed.
    roadmap_data: dict  # Raw JSON data from the roadmap

class PastCourse(BaseModel):
    """A course the student has already taken, and whether it was passed."""
    code: str
    name: str
    passed: bool

    # Pydantic v2 configuration (this module relies on v2's `model_dump()`).
    # The v1 `class Config: allow_population_by_field_name = True` spelling is
    # deprecated under v2 and emits a warning; `populate_by_name` is its v2 name.
    # No field declares an alias yet, so this only takes effect once one is added.
    model_config = {"populate_by_name": True}

class NextCourse(BaseModel):
    """A course from the roadmap's next-year list, with its enrolment constraints."""
    code: str
    name: str
    entry_requirements: str
    corequisites: str
    restrictions: str | None = None  # None when the roadmap gives no restrictions

class EvaluationRequest(BaseModel):
    """Degree details, course history and next-year options for a roadmap evaluation."""
    degree_name: str
    degree_code: str
    notes: str  # forwarded to the model server as "degree_notes"
    past_courses: list[PastCourse]
    next_courses: list[NextCourse]

def extract_evaluation_data(json_data: dict) -> dict:
    '''
    Extract the data from roadmap.json that EvaluationRequest needs.

    Used to experiment with how much data is needed for an optimal roadmap evaluation.
    Will be removed in production when the frontend is changed.

    Missing keys AND explicit JSON nulls fall back to defaults, so the result
    validates against EvaluationRequest:
      - text fields default to "" (a past course's name to "Course not found"),
      - "passed" defaults to False,
      - "restrictions" stays None when absent,
      - entries of the course lists that are not objects are skipped.

    Args:
        json_data: Parsed roadmap JSON. Reads "degree_info",
            "completed_courses_details" and "next_year_courses".

    Returns:
        A dict with keys degree_name, degree_code, notes, past_courses, next_courses.
    '''
    # Use `or` rather than a .get() default so that keys present with a null
    # value also fall back (a .get default only applies to missing keys).
    degree_info = json_data.get("degree_info") or {}
    if not isinstance(degree_info, dict):
        degree_info = {}

    past_courses = [
        {
            "code": course.get("code") or "",
            "name": course.get("name") or "Course not found",
            "passed": course.get("passed") or False,
        }
        for course in (json_data.get("completed_courses_details") or [])
        if isinstance(course, dict)
    ]

    next_courses = [
        {
            "code": course.get("code") or "",
            "name": course.get("name") or "",
            "entry_requirements": course.get("course_entry_requirements") or "",
            "corequisites": course.get("corequisites") or "",
            "restrictions": course.get("restrictions"),  # None if not present
        }
        for course in (json_data.get("next_year_courses") or [])
        if isinstance(course, dict)
    ]

    return {
        "degree_name": degree_info.get("degree_name") or "",
        "degree_code": degree_info.get("degree_code") or "",
        "notes": degree_info.get("degree_notes") or "",
        "past_courses": past_courses,
        "next_courses": next_courses,
    }

@app.get("/health")
async def health_check():
    """
    Report the model server's health.

    Forwards GET /health to the model server (SERVICES['chatbot']) and returns its
    JSON body unchanged, whatever the status code. If the model server cannot be
    reached, or replies with a body that is not JSON, a fallback status with
    "model_loaded": False is returned instead of an error.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{SERVICES['chatbot']}/health")
            return response.json()
        except httpx.RequestError:
            return {"status": "model server unavailable", "model_loaded": False}
        except ValueError:
            # The body was not valid JSON (e.g. an HTML error page); report it as
            # unhealthy instead of failing this endpoint with a 500.
            return {"status": "model server returned an invalid response", "model_loaded": False}

@app.get("/status")
async def status_check():
    """
    Proxy the model server's /status endpoint.

    Returns:
        The model server's JSON status body.

    Raises:
        HTTPException: with the upstream status code and body if the model server
            answers with an error status; 503 if it cannot be reached; 502 if it
            answers with a body that is not valid JSON.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{SERVICES['chatbot']}/status")
            response.raise_for_status()
            return response.json()
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text) from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=503, detail="Model server unavailable") from e
    except ValueError as e:
        # response.json() failed: the upstream reply is unusable, so report a bad gateway.
        raise HTTPException(status_code=502, detail="Model server returned an invalid response") from e

async def _search_similar_question_ids(client: "httpx.AsyncClient", question: str) -> list:
    """
    Step 1: ask the vector-search endpoint for ids of similar, previously answered questions.

    Best-effort, like the database step: any failure is logged and an empty list is
    returned so the answer can still be generated without example questions.
    """
    try:
        # NOTE(review): this goes to the model server (SERVICES['chatbot']), not to
        # SERVICES['faiss']; confirm which service owns /search.
        response = await client.post(
            f"{SERVICES['chatbot']}/search",
            json={
                "question": question,
                "k": 5,
                "max_results": 2,
                "min_similarity": 0.2,
            },
        )
        response.raise_for_status()
        question_ids = response.json().get("indices") or []
    except httpx.HTTPStatusError as e:
        print(f"Vector search failed with HTTP error {e.response.status_code}: {e.response.text}")
        return []
    except httpx.RequestError as e:
        print(f"Vector search unavailable: {str(e)}")
        return []
    except Exception as e:
        print(f"Unexpected vector search error: {str(e)}")
        return []
    print("Vector search success")
    return question_ids


async def _query_database(client: "httpx.AsyncClient", question: str, max_rows: int = 100):
    """
    Step 2: ask the SQL service for rows relevant to *question*.

    Best-effort: returns the service's JSON reply, or {"error": <message>} if the
    query fails, so the answer can be generated without database context.
    """
    try:
        response = await client.post(
            f"{SERVICES['sql']}/query",
            json={
                "question": question,
                "max_rows": max_rows,
                "debug": False,
            },
            timeout=15.0,
        )
        response.raise_for_status()
        results = response.json()
    except httpx.HTTPStatusError as e:
        print(f"Database query failed with HTTP error {e.response.status_code}: {e.response.text}")
        return {"error": f"Database query failed: HTTP {e.response.status_code}"}
    except httpx.RequestError as e:
        print(f"Database connection failed: {str(e)}")
        return {"error": f"Database unavailable: {str(e)}"}
    except Exception as e:
        print(f"Unexpected database error: {str(e)}")
        return {"error": f"Database error: {str(e)}"}
    print("Database query successful")
    return results


async def _fetch_full_questions(client: "httpx.AsyncClient", question_ids: list) -> list:
    """
    Step 3: fetch the full advisor questions for *question_ids* from the MyAdvisor backend.

    Best-effort: returns [] when there are no ids or the backend call fails.
    """
    print("Attempting to fetch questions by ids from the database.")
    if not question_ids:
        print("No relevant similar questions found, skipping database call.")
        return []
    try:
        response = await client.post(
            f"{API_URL}advisor_questions_fetch/",
            json={"question_ids": question_ids},
            headers={"Content-Type": "application/json"},
            timeout=15.0,
        )
        print(f"Advisor questions response status: {response.status_code}")
        response.raise_for_status()
        questions = response.json().get("questions", [])
        print(f"Retrieved {len(questions)} similar questions")
        return questions
    except httpx.RequestError as e:
        print(f"External API unavailable: {str(e)}")
        return []
    except Exception as e:
        print(f"External API error: {str(e)}")
        return []


def _format_example_questions(full_questions) -> list:
    """Turn fetched question dicts into the example strings passed as `vector_info`."""
    return [
        f"Example question, do not answer: {q['question']}\n Answer for example question: {q['answer']}"
        for q in full_questions
        if isinstance(q, dict) and "question" in q and "answer" in q
    ]


def _rag_error_line(content: str, context_data: dict, partial_line: bool = False) -> str:
    """
    Format a terminal NDJSON error record for the RAG stream.

    When the stream was cut in the middle of a record (*partial_line*), a newline is
    emitted first so the error record is not glued onto the truncated one.
    """
    error_data = {
        "type": "error",
        "content": content,
        "status": "error",
        "context_gathered": context_data,
    }
    return ("\n" if partial_line else "") + f"{json.dumps(error_data)}\n"


@app.post("/ask/stream-json")
async def ask_question_rag_stream_json(request: InferenceRequest):
    """
    RAG (Retrieval-Augmented Generation) streaming endpoint

    This endpoint orchestrates multiple API calls to provide context-enhanced responses:
    1. Searches the vector index for similar question IDs
    2. Queries SQL database for additional context
    3. Fetches full question details using the returned question IDs
    4. Streams the enhanced response from the model server

    Steps 1-3 are best-effort: a failure is logged and the answer is generated with
    whatever context was gathered. Only a failure of the model server call (step 4)
    ends the stream with a final NDJSON record whose "type" is "error".
    """
    async def rag_json_stream():
        context_data = {
            "database_results": None,
            "full_questions": [],
            "search_question_ids": [],
        }
        # True while the last forwarded chunk ended without a newline.
        partial_line = False
        async with httpx.AsyncClient(timeout=300.0) as client:
            try:
                context_data["search_question_ids"] = await _search_similar_question_ids(
                    client, request.question
                )
                context_data["database_results"] = await _query_database(client, request.question)
                context_data["full_questions"] = await _fetch_full_questions(
                    client, context_data["search_question_ids"]
                )
                vector_info = _format_example_questions(context_data["full_questions"])
                print(f"Prepared {len(vector_info)} example questions for context")

                print("Generating final response with available context")
                async with client.stream(
                    "POST",
                    f"{SERVICES['chatbot']}/infer/stream-json",
                    json={
                        "question": request.question,
                        "context": "\n".join(
                            f"Q: {item.question}\nA: {item.answer}"
                            for item in (request.context or [])
                        ),
                        "vector_info": vector_info,
                        "database_info": context_data["database_results"],
                    },
                ) as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_text():
                        # Forward chunks verbatim. Chunk boundaries are arbitrary network
                        # boundaries, so dropping whitespace-only chunks (e.g. a lone "\n")
                        # would merge neighbouring NDJSON records into one line.
                        if chunk:
                            partial_line = not chunk.endswith("\n")
                            yield chunk

            except httpx.HTTPStatusError as e:
                # Model server answered with an error status
                yield _rag_error_line(
                    f"Critical service error: {e.response.status_code}", context_data, partial_line
                )
            except httpx.RequestError as e:
                # Model server unreachable or the connection dropped
                yield _rag_error_line(
                    f"Critical service unavailable: {str(e)}", context_data, partial_line
                )
            except Exception as e:
                # Unexpected critical error
                yield _rag_error_line(f"Unexpected error: {str(e)}", context_data, partial_line)

    return StreamingResponse(
        rag_json_stream(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*"
        }
    )


@app.post("/roadmap-evaluation/stream")
async def roadmap_evaluation_stream(request: EvaluationRequest):
    """
    Stream an AI roadmap evaluation as NDJSON.
    Streaming version not implemented in the frontend yet.
    Future work.

    The past courses are split into passed and failed lists and sent, with the degree
    details and next courses, to the model server's /roadmap-evaluation endpoint. Its
    response body is relayed chunk by chunk. Upstream failures end the stream with a
    final NDJSON record whose "type" is "error".
    """
    async def roadmap_json_stream():
        # No retrieval context is gathered here. The empty structure is kept so error
        # records have the same shape as those of /ask/stream-json.
        context_data = {
            "database_results": None,
            "full_questions": [],
            "search_question_ids": [],
        }
        # True while the last forwarded chunk ended without a newline.
        partial_line = False

        def error_line(content: str, after_partial_line: bool) -> str:
            # Start on a fresh line if the upstream stream was cut mid-record.
            error_data = {
                "type": "error",
                "content": content,
                "status": "error",
                "context_gathered": context_data,
            }
            return ("\n" if after_partial_line else "") + f"{json.dumps(error_data)}\n"

        async with httpx.AsyncClient(timeout=300.0) as client:
            try:
                # NOTE(review): the non-streaming endpoint posts to /infer/evaluation;
                # confirm the model server really exposes /roadmap-evaluation.
                async with client.stream(
                    "POST",
                    f"{SERVICES['chatbot']}/roadmap-evaluation",
                    json={
                        "degree_name": request.degree_name,
                        "degree_code": request.degree_code,
                        "degree_notes": request.notes,
                        "passed_courses": [course.model_dump() for course in request.past_courses if course.passed],
                        "failed_courses": [course.model_dump() for course in request.past_courses if not course.passed],
                        "next_courses": [course.model_dump() for course in request.next_courses],
                    }
                ) as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_text():
                        # Forward chunks verbatim: dropping whitespace-only chunks
                        # (e.g. a lone "\n") would merge neighbouring NDJSON records.
                        if chunk:
                            partial_line = not chunk.endswith("\n")
                            yield chunk

            except httpx.HTTPStatusError as e:
                yield error_line(f"Service error: {e.response.status_code}", partial_line)
            except httpx.RequestError as e:
                yield error_line(f"Service unavailable: {str(e)}", partial_line)
            except Exception as e:
                yield error_line(f"Unexpected error: {str(e)}", partial_line)

    return StreamingResponse(
        roadmap_json_stream(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*"
        }
    )


@app.post("/roadmap-evaluation", response_model=RoadmapEvaluationResponse)
async def roadmap_evaluation(roadmap_data: dict):
    """
    Get an AI roadmap evaluation - returns simple text response.
    Current (non-streaming) implementation.

    The raw roadmap JSON is reduced to an EvaluationRequest. The past courses are split
    into passed and failed lists, and everything is posted to the model server's
    /infer/evaluation endpoint.

    Failures are reported as a RoadmapEvaluationResponse with success=False and a
    short description in `evaluation`, rather than raised.
    """
    try:
        evaluation_request = EvaluationRequest(**extract_evaluation_data(roadmap_data))
    except Exception as e:
        # The roadmap JSON could not be turned into a valid EvaluationRequest.
        return RoadmapEvaluationResponse(
            evaluation=f"Request processing error: {str(e)}",
            success=False,
        )

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(
                f"{SERVICES['chatbot']}/infer/evaluation",
                json={
                    "degree_name": evaluation_request.degree_name,
                    "degree_code": evaluation_request.degree_code,
                    "degree_notes": evaluation_request.notes,
                    "passed_courses": [course.model_dump() for course in evaluation_request.past_courses if course.passed],
                    "failed_courses": [course.model_dump() for course in evaluation_request.past_courses if not course.passed],
                    "next_courses": [course.model_dump() for course in evaluation_request.next_courses],
                },
            )
            response.raise_for_status()
            result = response.json()

        if not isinstance(result, dict):
            return RoadmapEvaluationResponse(
                evaluation="Unexpected error: evaluation service returned a non-object response",
                success=False,
            )
        return RoadmapEvaluationResponse(
            # `or` also covers an explicit null evaluation, which would fail validation.
            evaluation=result.get("evaluation") or "No evaluation generated",
            success=result.get("success", True),
        )

    except httpx.HTTPStatusError as e:
        # Report only the status code, not the full exception text with the upstream URL.
        return RoadmapEvaluationResponse(
            evaluation=f"Service error: HTTP {e.response.status_code}",
            success=False,
        )
    except httpx.RequestError as e:
        return RoadmapEvaluationResponse(
            evaluation=f"Service unavailable: {str(e)}",
            success=False,
        )
    except Exception as e:
        return RoadmapEvaluationResponse(
            evaluation=f"Unexpected error: {str(e)}",
            success=False,
        )
    

if __name__ == "__main__":
    # Direct (debugging) launch; see the module docstring for the supervisor-managed service.
    from uvicorn import run as serve
    serve(app, port=8443, host="0.0.0.0")